"""Wrapper around the Tencent vector database."""
from __future__ import annotations

import json
import logging
import time
from typing import Any, Dict, Iterable, List, Optional, Tuple

import numpy as np

from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.utils import guard_import
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance

logger = logging.getLogger(__name__)


class ConnectionParams:
    """Connection settings for a Tencent vector database server.

    See the following documentation for details:
    https://cloud.tencent.com/document/product/1709/95820

    Attributes:
        url (str): Access address of the vector database server the
            client connects to.
        key (str): API key the client authenticates with.
        username (str): Account used to access the server
            (defaults to ``"root"``).
        timeout (int): Request timeout (defaults to ``10``).
    """

    def __init__(self, url: str, key: str, username: str = "root", timeout: int = 10):
        # Where to connect and which credential to present.
        self.url, self.key = url, key
        # Who is connecting and how long a request may take.
        self.username, self.timeout = username, timeout


class IndexParams:
    """Index and collection settings for a Tencent vector database.

    See the following documentation for details:
    https://cloud.tencent.com/document/product/1709/95826

    Attributes:
        dimension (int): Dimension of the vector index.
        shard (int): Shard count passed when creating the collection.
        replicas (int): Replica count passed when creating the collection.
        index_type (str): Name of the vector index type (default ``"HNSW"``).
        metric_type (str): Name of the distance metric (default ``"L2"``).
        params (Optional[Dict]): Extra index-construction options; the
            wrapper reads the ``"M"`` and ``"efConstruction"`` keys.
    """

    def __init__(
        self,
        dimension: int,
        shard: int = 1,
        replicas: int = 2,
        index_type: str = "HNSW",
        metric_type: str = "L2",
        params: Optional[Dict] = None,
    ):
        # Vector shape.
        self.dimension = dimension
        # Collection layout.
        self.shard, self.replicas = shard, replicas
        # Index algorithm, distance metric and their tuning options.
        self.index_type, self.metric_type = index_type, metric_type
        self.params = params


class TencentVectorDB(VectorStore):
    """Vector store wrapper around the Tencent vector database.

    In order to use this you need to have a database instance.
    See the following documentation for details:
    https://cloud.tencent.com/document/product/1709/94951

    Every text is stored as one document with four fields: a string primary
    key (``field_id``), the embedding (``field_vector``), the raw text
    (``field_text``) and the metadata serialized as a JSON string
    (``field_metadata``).
    """

    field_id: str = "id"
    field_vector: str = "vector"
    field_text: str = "text"
    field_metadata: str = "metadata"

    def __init__(
        self,
        embedding: Embeddings,
        connection_params: ConnectionParams,
        index_params: Optional[IndexParams] = None,
        database_name: str = "LangChainDatabase",
        collection_name: str = "LangChainCollection",
        drop_old: Optional[bool] = False,
    ):
        """Connect to the server and open (or create) the target collection.

        Args:
            embedding: Embedding model used for documents and queries.
            connection_params: Server address and credentials.
            index_params: Index settings used when a collection has to be
                created. Defaults to a fresh ``IndexParams(128)``.
            database_name: Database to use; created if it is not listed.
            collection_name: Collection to use; created if describing it
                fails.
            drop_old: When true and the collection already exists, drop it
                and create it again.
        """
        self.document = guard_import("tcvectordb.model.document")
        tcvectordb = guard_import("tcvectordb")
        self.embedding_func = embedding
        # Build the default per instance: a default evaluated in the signature
        # would be one mutable object shared by every instance.
        if index_params is None:
            index_params = IndexParams(128)
        self.index_params = index_params
        self.vdb_client = tcvectordb.VectorDBClient(
            url=connection_params.url,
            username=connection_params.username,
            key=connection_params.key,
            timeout=connection_params.timeout,
        )
        db_exist = any(
            database_name == db.database_name
            for db in self.vdb_client.list_databases()
        )
        if db_exist:
            self.database = self.vdb_client.database(database_name)
        else:
            self.database = self.vdb_client.create_database(database_name)
        # Only the describe call is guarded: a failure there is treated as
        # "collection does not exist yet". Errors raised while dropping or
        # re-creating the collection propagate to the caller.
        try:
            self.collection = self.database.describe_collection(collection_name)
        except tcvectordb.exceptions.VectorDBException:
            self._create_collection(collection_name)
        else:
            if drop_old:
                self.database.drop_collection(collection_name)
                self._create_collection(collection_name)

    def _create_collection(self, collection_name: str) -> None:
        """Create the collection and its indexes from ``self.index_params``.

        Args:
            collection_name: Name of the collection to create.

        Raises:
            ValueError: If ``index_type`` or ``metric_type`` is not the name
                of a member of the corresponding tcvectordb enum.
        """
        enum = guard_import("tcvectordb.model.enum")
        vdb_index = guard_import("tcvectordb.model.index")
        index_type = enum.IndexType.__members__.get(self.index_params.index_type)
        if index_type is None:
            raise ValueError("unsupported index_type")
        metric_type = enum.MetricType.__members__.get(self.index_params.metric_type)
        if metric_type is None:
            raise ValueError("unsupported metric_type")
        # HNSW construction parameters are built regardless of index_type;
        # missing options fall back to M=16 / efConstruction=200.
        extra = self.index_params.params or {}
        params = vdb_index.HNSWParams(
            m=extra.get("M", 16),
            efconstruction=extra.get("efConstruction", 200),
        )
        index = vdb_index.Index(
            vdb_index.FilterIndex(
                self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY
            ),
            vdb_index.VectorIndex(
                self.field_vector,
                self.index_params.dimension,
                index_type,
                metric_type,
                params,
            ),
            vdb_index.FilterIndex(
                self.field_text, enum.FieldType.String, enum.IndexType.FILTER
            ),
            vdb_index.FilterIndex(
                self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER
            ),
        )
        self.collection = self.database.create_collection(
            name=collection_name,
            shard=self.index_params.shard,
            replicas=self.index_params.replicas,
            description="Collection for LangChain",
            index=index,
        )

    @property
    def embeddings(self) -> Embeddings:
        """Return the embedding model this store was constructed with."""
        return self.embedding_func

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        connection_params: Optional[ConnectionParams] = None,
        index_params: Optional[IndexParams] = None,
        database_name: str = "LangChainDatabase",
        collection_name: str = "LangChainCollection",
        drop_old: Optional[bool] = False,
        **kwargs: Any,
    ) -> TencentVectorDB:
        """Create a collection, index it with HNSW, and insert data.

        The index dimension is derived by embedding the first text.

        Args:
            texts: Texts to embed and store; must not be empty.
            embedding: Embedding model to use.
            metadatas: Optional metadata dicts, aligned with ``texts``.
            connection_params: Server address and credentials; required.
            index_params: Optional index settings. NOTE: its ``dimension``
                is overwritten in place with the detected embedding size.
            database_name: Database to use.
            collection_name: Collection to use.
            drop_old: Drop an existing collection of the same name first.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The populated vector store.

        Raises:
            ValueError: If ``texts`` is empty or ``connection_params`` is None.
        """
        if len(texts) == 0:
            raise ValueError("texts is empty")
        if connection_params is None:
            raise ValueError("connection_params is empty")
        # Embed a single sample only to learn the vector dimension.
        try:
            embeddings = embedding.embed_documents(texts[0:1])
        except NotImplementedError:
            embeddings = [embedding.embed_query(texts[0])]
        dimension = len(embeddings[0])
        if index_params is None:
            index_params = IndexParams(dimension=dimension)
        else:
            index_params.dimension = dimension
        vector_db = cls(
            embedding=embedding,
            connection_params=connection_params,
            index_params=index_params,
            database_name=database_name,
            collection_name=collection_name,
            drop_old=drop_old,
        )
        vector_db.add_texts(texts=texts, metadatas=metadatas)
        return vector_db

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        timeout: Optional[int] = None,
        batch_size: int = 1000,
        **kwargs: Any,
    ) -> List[str]:
        """Insert text data into TencentVectorDB.

        Args:
            texts: Texts to embed and store.
            metadatas: Optional metadata dicts, aligned with ``texts``; each
                is stored as a JSON string (``"{}"`` when omitted).
            timeout: Timeout forwarded to each upsert call.
            batch_size: Maximum number of documents per upsert call.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The ids of the stored documents, in the order of ``texts``.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        texts = list(texts)
        if not texts:
            logger.debug("Nothing to insert, skipping.")
            return []
        if batch_size < 1:
            # A negative step would make range() yield nothing and silently
            # drop every text; zero would fail inside range().
            raise ValueError("batch_size must be a positive integer")
        try:
            embeddings = self.embedding_func.embed_documents(texts)
        except NotImplementedError:
            embeddings = [self.embedding_func.embed_query(x) for x in texts]
        if len(embeddings) == 0:
            logger.debug("Nothing to insert, skipping.")
            return []
        pks: List[str] = []
        total_count = len(embeddings)
        for start in range(0, total_count, batch_size):
            end = min(start + batch_size, total_count)
            docs = []
            for idx in range(start, end):
                metadata = "{}" if metadatas is None else json.dumps(metadatas[idx])
                doc_id = "{}-{}-{}".format(time.time_ns(), hash(texts[idx]), idx)
                docs.append(
                    self.document.Document(
                        id=doc_id,
                        vector=embeddings[idx],
                        text=texts[idx],
                        metadata=metadata,
                    )
                )
                # Report the id that was actually stored, not the position.
                pks.append(doc_id)
            self.collection.upsert(docs, timeout)
        return pks

    def _query_collection(
        self,
        embedding: List[float],
        limit: int,
        param: Optional[dict],
        expr: Optional[str],
        timeout: Optional[int],
        retrieve_vector: bool,
    ) -> List[Dict]:
        """Run one vector search and return the hits for the single query."""
        search_filter = None if expr is None else self.document.Filter(expr)
        ef = 10 if param is None else param.get("ef", 10)
        res: List[List[Dict]] = self.collection.search(
            vectors=[embedding],
            filter=search_filter,
            params=self.document.HNSWSearchParams(ef=ef),
            retrieve_vector=retrieve_vector,
            limit=limit,
            timeout=timeout,
        )
        # One query vector was sent, so only the first result list matters.
        if not res:
            return []
        return res[0]

    def _result_to_document(self, result: Dict) -> Document:
        """Convert one search hit into a LangChain ``Document``."""
        raw_meta = result.get(self.field_metadata)
        # A hit without a metadata field becomes an empty dict rather than
        # None (assumes Document.metadata expects a dict — TODO confirm).
        meta = {} if raw_meta is None else json.loads(raw_meta)
        return Document(page_content=result.get(self.field_text), metadata=meta)

    def similarity_search(
        self,
        query: str,
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Perform a similarity search against the query string.

        Args:
            query: Text to search for.
            k: Maximum number of documents to return.
            param: Optional search options; only ``"ef"`` is read.
            expr: Optional filter expression.
            timeout: Timeout forwarded to the search call.
            **kwargs: Forwarded to ``similarity_search_with_score``.

        Returns:
            The matching documents, without scores.
        """
        res = self.similarity_search_with_score(
            query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs
        )
        return [doc for doc, _ in res]

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Perform a search on a query string and return results with score.

        Args:
            query: Text to search for; embedded with ``embed_query``.
            k: Maximum number of documents to return.
            param: Optional search options; only ``"ef"`` is read.
            expr: Optional filter expression.
            timeout: Timeout forwarded to the search call.
            **kwargs: Forwarded to ``similarity_search_with_score_by_vector``.

        Returns:
            ``(document, score)`` pairs.
        """
        embedding = self.embedding_func.embed_query(query)
        return self.similarity_search_with_score_by_vector(
            embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
        )

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Perform a similarity search against the query vector.

        Args:
            embedding: Query vector.
            k: Maximum number of documents to return.
            param: Optional search options; only ``"ef"`` is read.
            expr: Optional filter expression.
            timeout: Timeout forwarded to the search call.
            **kwargs: Forwarded to ``similarity_search_with_score_by_vector``.

        Returns:
            The matching documents, without scores.
        """
        res = self.similarity_search_with_score_by_vector(
            embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
        )
        return [doc for doc, _ in res]

    def similarity_search_with_score_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Perform a search on a query vector and return results with score.

        Args:
            embedding: Query vector.
            k: Maximum number of documents to return.
            param: Optional search options; only ``"ef"`` is read
                (default 10).
            expr: Optional filter expression.
            timeout: Timeout forwarded to the search call.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            ``(document, score)`` pairs; the score is the hit's ``"score"``
            entry, or ``0.0`` when absent. Empty when nothing is found.
        """
        hits = self._query_collection(
            embedding=embedding,
            limit=k,
            param=param,
            expr=expr,
            timeout=timeout,
            retrieve_vector=False,
        )
        return [
            (self._result_to_document(hit), hit.get("score", 0.0)) for hit in hits
        ]

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Perform a search and return results that are reordered by MMR.

        Args:
            query: Text to search for; embedded with ``embed_query``.
            k: Number of documents to select.
            fetch_k: Number of candidates fetched before reordering.
            lambda_mult: Trade-off value passed to the MMR routine.
            param: Optional search options; only ``"ef"`` is read.
            expr: Optional filter expression.
            timeout: Timeout forwarded to the search call.
            **kwargs: Forwarded to
                ``max_marginal_relevance_search_by_vector``.

        Returns:
            The selected documents in MMR order.
        """
        embedding = self.embedding_func.embed_query(query)
        return self.max_marginal_relevance_search_by_vector(
            embedding=embedding,
            k=k,
            fetch_k=fetch_k,
            lambda_mult=lambda_mult,
            param=param,
            expr=expr,
            timeout=timeout,
            **kwargs,
        )

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: list[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Perform a search and return results that are reordered by MMR.

        Args:
            embedding: Query vector.
            k: Number of documents to select.
            fetch_k: Number of candidates fetched before reordering.
            lambda_mult: Trade-off value passed to the MMR routine.
            param: Optional search options; only ``"ef"`` is read
                (default 10).
            expr: Optional filter expression.
            timeout: Timeout forwarded to the search call.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The selected documents in MMR order; empty when the search
            returns nothing.
        """
        hits = self._query_collection(
            embedding=embedding,
            limit=fetch_k,
            param=param,
            expr=expr,
            timeout=timeout,
            retrieve_vector=True,
        )
        # Same empty-result guard as the plain similarity search.
        if not hits:
            return []
        documents = [self._result_to_document(hit) for hit in hits]
        ordered_result_embeddings = [hit.get(self.field_vector) for hit in hits]
        # Get the new order of results.
        new_ordering = maximal_marginal_relevance(
            np.array(embedding), ordered_result_embeddings, k=k, lambda_mult=lambda_mult
        )
        # Reorder the values and return.
        ret = []
        for position in new_ordering:
            # Function can return -1 index
            if position == -1:
                break
            ret.append(documents[position])
        return ret
